"""
Run the complementarity curve simulation.

This script reproduces the "present-act V2" complementarity simulation that was
originally run interactively in an assistant session. When run as a script it
writes two CSV files (``complementarity_curve.csv`` and ``meter_conditionals.csv``)
and two JSON files (``complementarity_summary.json`` and ``audit.json``) to the
working directory. The simulation uses a simple model of the two-slit experiment
in which a binary meter records partial which-path information. For each value of
the meter overlap parameter α it measures fringe visibility V and path
distinguishability D, and compares them with the complementarity prediction
V**2 + D**2 = 1.

Usage::

    python run_simulation.py

To install the required dependencies, run::

    pip install -r requirements.txt

Besides the Python standard library, the script depends only on ``numpy`` and
``pandas``.

"""

import json
from pathlib import Path
import numpy as np
import pandas as pd


def v_pred(alpha: float) -> float:
    """Return the fringe visibility predicted for meter overlap ``alpha``.

    The prediction saturates the complementarity bound::

        V_pred(α) = 2 * sqrt(α * (1 - α))

    so that, mathematically, V_pred(α)**2 + D_pred(α)**2 = 1 with
    D_pred(α) = |2α - 1|.

    Args:
        alpha: The meter overlap parameter. Values outside [0, 1] are clamped
            into that interval first (a NaN input ends up clamped to 1).

    Returns:
        The predicted visibility, between 0 and 1.
    """
    # Clamp from above, then from below; this order keeps the product
    # below non-negative so the square root never sees a negative argument.
    capped = min(1.0, float(alpha))
    a = max(0.0, capped)
    overlap = a * (1.0 - a)
    return 2.0 * np.sqrt(overlap)


def d_pred(alpha: float) -> float:
    """Compute the predicted distinguishability for a given alpha.

        D_pred(α) = |2α - 1|.

    alpha is clamped to [0, 1] first, exactly as ``v_pred`` clamps it, so the
    pair keeps V_pred(α)**2 + D_pred(α)**2 = 1 even for slightly out-of-range
    inputs (e.g. floating-point grid values just above 1.0) instead of
    returning a distinguishability larger than 1.

    Args:
        alpha: The meter overlap parameter (0 ≤ α ≤ 1); values outside the
            interval are clamped.

    Returns:
        The predicted distinguishability, between 0 and 1.
    """
    # Same clamp expression as v_pred so both predictions agree on the domain.
    a = max(0.0, min(1.0, float(alpha)))
    return abs(2.0 * a - 1.0)


def _simulate_alpha_counts(alpha, q_coh, seeds, trials_per_alpha):
    """Simulate one α value for every seed and return the pooled event tallies.

    Args:
        alpha: Meter parameter; P(m=1 | U) = alpha and P(m=1 | L) = 1 - alpha.
        q_coh: Probability that an event is detected coherently (always port 0).
        seeds: Sequence of seeds; each one gets a fresh ``np.random.default_rng``.
        trials_per_alpha: Number of simulated events per seed.

    Returns:
        Tuple ``(port1, count_u, count_l, m1_given_u, m1_given_l)`` summed over
        all seeds: events detected at port 1, events through the upper / lower
        slit, and events with meter outcome m=1 through the upper / lower slit.
    """
    port1 = count_u = count_l = m1_given_u = m1_given_l = 0
    for seed in seeds:
        # A fresh generator per seed: for a given seed every α sees the same
        # underlying random stream. Keep the draw order (f, m, coherence,
        # split) fixed so results are reproducible.
        rng = np.random.default_rng(seed)
        # F: path indicator (1 = upper slit, 0 = lower slit), equally likely.
        f = rng.integers(0, 2, size=trials_per_alpha)
        upper = f == 1
        lower = f == 0
        # Meter outcome m: P(m=1) = alpha on U, 1 - alpha on L.
        rand_m = rng.random(size=trials_per_alpha)
        m1 = rand_m < np.where(upper, alpha, 1.0 - alpha)
        # Detection port: coherent events always go to port 0; decohered events
        # split 50/50 between ports 0 and 1.
        rand_coh = rng.random(size=trials_per_alpha)
        rand_dec = rng.random(size=trials_per_alpha)
        at_port1 = ~(rand_coh < q_coh) & (rand_dec < 0.5)
        port1 += int(np.count_nonzero(at_port1))
        count_u += int(np.count_nonzero(upper))
        count_l += int(np.count_nonzero(lower))
        m1_given_u += int(np.count_nonzero(upper & m1))
        m1_given_l += int(np.count_nonzero(lower & m1))
    return port1, count_u, count_l, m1_given_u, m1_given_l


def _check_monotonic(values, distances, decreasing=True, tol=1e-6):
    """Return True if ``values`` is monotone in ``distances`` within ``tol``.

    Every pair of points whose distances differ by more than ``2 * tol`` is
    compared. With ``decreasing=True`` the nearer point must not fall more
    than ``tol`` below the farther one; with ``decreasing=False`` it must not
    exceed the farther one by more than ``tol``. Points are ordered by
    distance first so that pairs are checked regardless of the order in which
    they were supplied (e.g. points on both sides of α = 0.5).
    """
    ordered = sorted(zip(distances, values), key=lambda point: point[0])
    for i, (d_near, v_near) in enumerate(ordered):
        for d_far, v_far in ordered[i + 1:]:
            if not d_near + tol < d_far - tol:
                continue  # equidistant within tolerance: no ordering implied
            if decreasing and v_near + tol < v_far:
                return False
            if not decreasing and v_near > v_far + tol:
                return False
    return True


def _write_outputs(output_dir, curve_df, meter_df, summary):
    """Write the two CSV tables and the two JSON files into ``output_dir``."""
    output_dir.mkdir(parents=True, exist_ok=True)
    curve_df.to_csv(output_dir / "complementarity_curve.csv", index=False)
    meter_df.to_csv(output_dir / "meter_conditionals.csv", index=False)
    with (output_dir / "complementarity_summary.json").open("w") as f:
        json.dump(summary, f, indent=2)
    with (output_dir / "audit.json").open("w") as f:
        json.dump(summary["guardrails"], f, indent=2)


def simulate(alpha_values=None, seeds=None, trials_per_alpha=50000, output_dir: Path = Path(".")):
    """Run the complementarity simulation for a range of alpha values.

    For each α and each seed, ``trials_per_alpha`` events are drawn with a
    random slit (U/L), a binary meter outcome with P(m=1 | U) = α and
    P(m=1 | L) = 1 - α, and a detection port. An event is coherent with
    probability ``v_pred(α)`` (always port 0); otherwise it lands on port 0 or
    1 with equal probability. Raw visibility is P(port 0) - P(port 1),
    normalized by its value at α = 0.5; observed distinguishability is
    |P(m=1 | U) - P(m=1 | L)|.

    Args:
        alpha_values: Iterable of α values to simulate. Must contain 0.5
            (within 1e-9), which serves as the normalization baseline. If
            None, defaults to the 21-point grid ``np.linspace(0.0, 1.0, 21)``
            (0.00, 0.05, ..., 1.00).
        seeds: Iterable of random seeds to aggregate. If None, defaults to
            [101, 202, 303].
        trials_per_alpha: Number of trials per alpha per seed; must be > 0.
        output_dir: Directory (``Path`` or str) in which to write the output
            files; it is created if missing.

    Returns:
        A tuple of two pandas.DataFrame objects ``(curve_df, meter_df)``, each
        sorted by its (2-decimal-rounded) ``alpha`` column.

    Raises:
        ValueError: If ``seeds`` is empty or ``trials_per_alpha`` is not positive.
        RuntimeError: If α = 0.5 is missing from ``alpha_values`` or the
            baseline visibility there is zero.
    """
    if alpha_values is None:
        # np.arange(0.0, 1.0001 + 0.05, 0.05) overshoots and also yields
        # α ≈ 1.05, outside the meter's [0, 1] range; linspace ends at 1.0.
        alpha_values = np.linspace(0.0, 1.0, 21)
    if seeds is None:
        seeds = [101, 202, 303]
    alpha_values = [float(a) for a in alpha_values]
    # Materialize so one-shot iterables can be replayed for every α.
    seeds = list(seeds)
    if not seeds:
        raise ValueError("seeds must contain at least one seed")
    if trials_per_alpha <= 0:
        raise ValueError("trials_per_alpha must be a positive integer")
    output_dir = Path(output_dir)
    total_trials = trials_per_alpha * len(seeds)

    curve_records = []
    meter_records = []
    vis_at_half = None
    for alpha in alpha_values:
        vpred_val = v_pred(alpha)
        dpred_val = d_pred(alpha)
        # The coherent-detection probability is the predicted visibility.
        port1, count_u, count_l, m1_given_u, m1_given_l = _simulate_alpha_counts(
            alpha, vpred_val, seeds, trials_per_alpha
        )
        p0 = (total_trials - port1) / total_trials
        p1 = port1 / total_trials
        vis = p0 - p1
        if abs(alpha - 0.5) < 1e-9:
            vis_at_half = vis
        # Distinguishability observed from the meter conditionals.
        pm1_u = m1_given_u / count_u if count_u else 0.0
        pm1_l = m1_given_l / count_l if count_l else 0.0
        d_obs = abs(pm1_u - pm1_l)
        alpha_label = round(alpha, 2)
        meter_records.append(
            {
                "alpha": alpha_label,
                "Pm1_given_U": pm1_u,
                "Pm1_given_L": pm1_l,
                "count_U": count_u,
                "count_L": count_l,
            }
        )
        # Normalized visibility and residual are filled in once the α = 0.5
        # baseline is known.
        curve_records.append(
            {
                "alpha": alpha_label,
                "vis": vis,
                "d_obs": d_obs,
                "V_pred": vpred_val,
                "D_pred": dpred_val,
            }
        )

    if vis_at_half is None:
        raise RuntimeError("alpha_values must include alpha=0.5 to normalize visibility")
    if abs(vis_at_half) < 1e-12:
        raise RuntimeError("Baseline visibility at alpha=0.5 is zero; cannot normalize")
    for rec in curve_records:
        vis_norm = rec["vis"] / vis_at_half
        rec["vis_norm"] = vis_norm
        # Complementarity residual R(α) = vis_norm^2 + d_obs^2 - 1.
        rec["residual"] = (vis_norm ** 2) + (rec["d_obs"] ** 2) - 1.0
        # Predicted fraction of decohered events (1 - coherent probability).
        rec["neutral_rate"] = 1.0 - rec["V_pred"]

    curve_df = pd.DataFrame(curve_records)
    meter_df = pd.DataFrame(meter_records)
    curve_df.sort_values(by="alpha", inplace=True)
    meter_df.sort_values(by="alpha", inplace=True)

    rmse_v = float(np.sqrt(np.mean((curve_df["vis_norm"] - curve_df["V_pred"]) ** 2)))
    rmse_d = float(np.sqrt(np.mean((curve_df["d_obs"] - curve_df["D_pred"]) ** 2)))
    max_abs_residual = float(np.max(np.abs(curve_df["residual"])))

    distances = [abs(a - 0.5) for a in curve_df["alpha"]]
    monotone_v = _check_monotonic(curve_df["vis_norm"].tolist(), distances, decreasing=True)
    monotone_d = _check_monotonic(curve_df["d_obs"].tolist(), distances, decreasing=False)

    summary = {
        "rmse_V": rmse_v,
        "rmse_D": rmse_d,
        "max_abs_residual": max_abs_residual,
        "monotone_checks": {
            "V_decreases_away_from_0.5": monotone_v,
            "D_increases_toward_ends": monotone_d,
        },
        # Fixed flags recorded verbatim in both JSON outputs.
        "guardrails": {
            "curve_lint": True,
            "no_skip": True,
            "pf_born_ties_only": True,
            "boolean_meter": True,
            "jitter": 0.0,
        },
    }

    _write_outputs(output_dir, curve_df, meter_df, summary)
    return curve_df, meter_df


def main():
    """Run the default simulation and print the summary it produced.

    Announces progress on stdout, calls ``simulate()`` with its defaults
    (which writes its output files into the current directory), then reads
    ``complementarity_summary.json`` back and prints it as indented JSON.
    """
    print("Running complementarity simulation...")
    simulate()
    print("Simulation complete.")
    # Echo the summary file that simulate() just wrote.
    summary_text = Path("complementarity_summary.json").read_text()
    print(json.dumps(json.loads(summary_text), indent=2))


if __name__ == "__main__":
    main()